import transformers


class AdamY(transformers.AdamW):
    """AdamW variant that exempts bias and LayerNorm parameters from weight decay.

    The named parameters of ``module`` are split into two groups: parameters
    whose name contains one of the no-decay markers get ``weight_decay=0.0``,
    all remaining parameters get ``weight_decay``. The base optimizer is
    always constructed with ``correct_bias=False``.
    """

    def __init__(self, module, lr=2e-5, weight_decay=0.01):
        """Build the two parameter groups and initialise the base optimizer.

        Args:
            module: Object exposing ``named_parameters()`` that yields
                ``(name, parameter)`` pairs (presumably a ``torch.nn.Module``
                -- confirm against callers).
            lr: Learning rate forwarded to the base optimizer.
            weight_decay: Decay coefficient for the group of parameters that
                match no no-decay marker. Defaults to 0.01.
        """
        # Markers are matched as substrings of the parameter name
        # (`marker in name`), so plain 'bias' also covers 'LayerNorm.bias'.
        no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

        # Single pass over the parameters; relative order inside each group
        # is the order in which named_parameters() yields them.
        decayed, exempt = [], []
        for name, param in module.named_parameters():
            if any(marker in name for marker in no_decay):
                exempt.append(param)
            else:
                decayed.append(param)

        param_groups = [
            {'params': decayed, 'weight_decay': weight_decay},
            {'params': exempt, 'weight_decay': 0.0},
        ]
        super().__init__(param_groups, lr=lr, correct_bias=False)
